import os, sys
import torch
from torch import nn
import numpy as np

def getgrad(model: torch.nn.Module, grad_dict: dict, step_iter=0):
    """Record the flattened weight gradient of every Conv2d/Linear layer.

    Each entry of ``grad_dict`` maps a module name (as yielded by
    ``model.named_modules()``) to a list of 1-D numpy arrays, one per call.
    Layers whose ``weight.grad`` is None are skipped.

    Args:
        model: Network whose ``weight.grad`` tensors have already been
            populated (e.g. by ``loss.backward()``).
        grad_dict: Accumulator, updated in place.
        step_iter: Index of the current batch. When 0, the entry of every
            layer that has a gradient is reset to a fresh one-element list;
            otherwise the gradient is appended to that layer's list.

    Returns:
        The same ``grad_dict`` object, updated.
    """
    for name, mod in model.named_modules():
        if not isinstance(mod, (nn.Conv2d, nn.Linear)):
            continue
        grad = mod.weight.grad
        if grad is None:
            continue
        # copy=True guarantees an independent snapshot. For a CPU tensor,
        # .cpu()/.numpy() would otherwise alias the live .grad buffer, and an
        # in-place zero_grad(set_to_none=False) plus the next backward would
        # silently overwrite every previously stored array.
        flat = grad.detach().reshape(-1).to("cpu", copy=True).numpy()
        if step_iter == 0:
            grad_dict[name] = [flat]
        else:
            # setdefault: a layer whose grad was None on step 0 would
            # otherwise raise KeyError here.
            grad_dict.setdefault(name, []).append(flat)
    return grad_dict

def caculate_zico(grad_dict):
    """Reduce per-layer gradient snapshots to a single summed log-ratio score.

    For each layer, the recorded gradients are stacked into a
    ``(num_steps, num_params)`` array. For each parameter, the mean absolute
    gradient across steps is divided by the standard deviation across steps;
    parameters with zero standard deviation are excluded. The layer
    contributes the natural log of the sum of those ratios, and the
    contributions of all layers are added up. A layer whose ratio sum is
    exactly 0 contributes nothing. This includes a layer in which every
    parameter has zero std, such as when only a single step was recorded.

    Args:
        grad_dict: Mapping from layer name to a sequence of equal-length 1-D
            gradient arrays, one per step (as built by ``getgrad``). It is
            not modified.

    Returns:
        The summed score (0 when no layer contributes).
    """
    score = 0
    for grads in grad_dict.values():
        grads = np.asarray(grads)
        std = np.std(grads, axis=0)
        nonzero_idx = np.nonzero(std)[0]
        mean_abs = np.mean(np.abs(grads), axis=0)
        ratio_sum = np.sum(mean_abs[nonzero_idx] / std[nonzero_idx])
        if ratio_sum != 0:
            score += np.log(ratio_sum)
    return score


def zico(train_loader, networks, batch_iter=2, train_mode=False, num_classes=100,  verbose=False):
    """Score each network from gradient statistics over a few training batches.

    For each network, up to ``batch_iter`` batches from ``train_loader`` are
    run forward and backward with cross-entropy loss. The Conv2d/Linear weight
    gradients of each batch are collected with ``getgrad`` and reduced with
    ``caculate_zico``.

    Args:
        train_loader: Iterable yielding ``(inputs, targets)`` batches. It is
            re-iterated from the start for every network.
        networks: Iterable of ``nn.Module``s. Each one is put in train mode
            and moved to the device. ``network(data)`` is assumed to return a
            2-tuple whose second element is the logits.
        batch_iter: Number of batches per network. With a single batch, every
            per-parameter std is 0, so the score is 0.
        train_mode, num_classes, verbose: Accepted for interface
            compatibility; unused.

    Returns:
        List with one entry per network: the negated ``caculate_zico`` score.
    """
    # Prefer the current CUDA device; fall back to CPU so the function does
    # not crash on machines without a GPU.
    if torch.cuda.is_available():
        device = torch.device("cuda", torch.cuda.current_device())
    else:
        device = torch.device("cpu")
    # CrossEntropyLoss without class weights holds no tensors, so it needs no
    # device placement.
    lossfunc = nn.CrossEntropyLoss()

    zicos = []
    for network in networks:
        network.train()
        network.to(device)
        # Fresh accumulator per network. A dict shared across networks would
        # let layers of an earlier network leak into a later network's score.
        grad_dict = {}
        for i, (inputs, targets) in enumerate(train_loader):
            # Note: the loader has already produced batch `batch_iter` when
            # this check fires; that batch is discarded.
            if i == batch_iter:
                break
            network.zero_grad(set_to_none=True)
            data, label = inputs.to(device), targets.to(device)

            _, logits = network(data)
            loss = lossfunc(logits, label)
            loss.backward()
            grad_dict = getgrad(network, grad_dict, i)

        res = caculate_zico(grad_dict)
        zicos.append(-res)

    return zicos